# Forward sampling for a continuous-time variable C -> D
# NOTE(review): rm(list = ls()) deletes every object in the environment this
# file is evaluated in (the caller's workspace when the file is source()d).
# It does not unload packages or reset options, so it is not a clean restart.
rm(list = ls())
# rcat() and rdirichlet() used below are presumably provided by LaplacesDemon.
library(LaplacesDemon)
# stan() and extract() used in estimatePar() are presumably provided by rstan.
library(rstan)
# write.list() used in run() is presumably provided by erer.
library(erer)
# Simulate one complete trajectory of the two-node model C -> D.
#
# C is a two-state continuous-time process that alternates between states 1
# and 2 with exponentially distributed holding times; D is a two-state
# discrete-time variable sampled at the integer slices 1..N, whose next state
# is drawn from D.tran given the state of C at that slice and the previous
# state of D.
#
# Arguments:
#   N       number of discrete-time slices; C is simulated until its clock
#           first exceeds N.
#   D.tran  2 x 2 x 2 array indexed [state of C, next state of D, previous
#           state of D]; D.tran[c, , d] is handed to rcat() as the
#           probability vector for the next state of D.
#   q1, q2  rates of leaving state 1 and state 2 of C (must be positive).
#   seed    seed passed to set.seed() before sampling; the default 1987
#           reproduces the historical behaviour. Pass NULL to leave the
#           global RNG state untouched.
#
# Returns a list with two elements:
#   incomp  list(time.traj, C.states, D.tran.states): jump times of C
#           (starting at 0), the visited states of C (one more entry than
#           time.traj, because the state entered by the final jump is kept)
#           and the N sampled states of D.
#   comp    list(D.C.states, D.states, C.states, time.int): the state of C
#           at slices 2..N, the states of D, the states of C and the holding
#           time spent in each visited state of C.
generateComp <- function(N, D.tran, q1, q2, seed = 1987)
{
  stopifnot(is.numeric(N), length(N) == 1, !is.na(N),
            is.numeric(q1), length(q1) == 1, q1 > 0,
            is.numeric(q2), length(q2) == 1, q2 > 0)
  if (!is.null(seed)) set.seed(seed)

  # initial distributions for D and C, no arcs in initial BN
  D.theta <- c(0.2, 0.8)
  C.theta <- c(0.4, 0.6)
  # rate of leaving each state of C (negated diagonal of the intensity matrix)
  leave.rate <- c(q1, q2)

  # draw order (D first, then C) is kept so seeded runs are reproducible
  D.currstate <- rcat(1, D.theta)
  C.currstate <- rcat(1, C.theta)

  # generate the complete trajectory of C; the number of jumps is not known
  # in advance, so the vectors are extended by index assignment
  time.traj <- numeric(0)
  time.int <- numeric(0)
  C.states <- C.currstate
  currTime <- 0
  n.jump <- 0L
  while (currTime <= N)
  {
    n.jump <- n.jump + 1L
    time.traj[n.jump] <- currTime
    holding <- rexp(1, rate = leave.rate[C.currstate])
    time.int[n.jump] <- holding
    # with only two states the next state is always the other one
    C.currstate <- 3 - C.currstate
    C.states[n.jump + 1L] <- C.currstate
    currTime <- currTime + holding
  }

  # generate D given C at slices 2..N (no slices when N < 2)
  slices <- if (N >= 2) 2:N else integer(0)
  # index of the last jump of C at or before each slice; time.traj is sorted
  C.ind <- findInterval(slices, time.traj)
  D.states <- rep(D.currstate, length(slices) + 1L)
  for (k in seq_along(slices))
  {
    trandis <- D.tran[C.states[C.ind[k]], , D.currstate]
    D.currstate <- rcat(1, trandis)
    D.states[k + 1L] <- D.currstate
  }

  incomp <- list(time.traj = time.traj, C.states = C.states, D.tran.states = D.states)
  comp <- list(D.C.states = C.states[C.ind], D.states = D.states, C.states = C.states, time.int = time.int)
  list(incomp = incomp, comp = comp)
}

# Estimate the intensities of C and the transition table of D from a
# trajectory in which C is only observed at random time points.
#
# Observation times for C are drawn as a Poisson process with the given rate
# over [0, max(time.traj)]; the state of C is read off at those times and,
# for every slice 2..N, the most recent observation of C is attached to D.
# The resulting data are passed to the Stan model "C->D(C missing).stan" and
# each parameter is summarised by the midpoint of the fullest bin of a
# histogram of its posterior draws.
#
# Arguments:
#   incomp  list with time.traj (jump times of C), C.states (states of C)
#           and D.tran.states (states of D), as returned in the incomp
#           element of generateComp().
#   q1, q2  true intensities; only used in the printed comparison.
#   D.tran  not used by this function; kept for interface compatibility.
#   rate    rate of the exponential gaps between observation times of C.
#   N       number of discrete-time slices.
#
# Returns list(inten, CPT): inten is c(estimate of q1, estimate of q2) and
# CPT is a 2 x 2 x 2 array with CPT[i, j, k] taken from the posterior draws
# la$D_CPT[, i, k, j].
estimatePar <- function(incomp, q1, q2, D.tran, rate, N)
{
  time.traj <- incomp$time.traj
  C.states <- incomp$C.states
  D.tran.states <- incomp$D.tran.states

  # Midpoint of the first histogram bin with the highest count. The
  # histogram is only a device for locating the mode, so it is not drawn.
  histMode <- function(draws)
  {
    h <- hist(draws, breaks = 20, plot = FALSE)
    h$mids[which.max(h$counts)]
  }

  # generate time-point evidence: rtp holds the gaps between accepted
  # observation times, rtp_sys the observation times themselves (starting
  # with an observation at time 0)
  horizon <- max(time.traj)
  rtp <- numeric(0)
  rtp_sys <- 0
  n.obs <- 0L
  gap <- rexp(1, rate = rate)
  sys_time <- gap
  while (sys_time <= horizon)
  {
    n.obs <- n.obs + 1L
    rtp[n.obs] <- gap
    rtp_sys[n.obs + 1L] <- sys_time
    gap <- rexp(1, rate = rate)
    sys_time <- sys_time + gap
  }

  # state of C at each observation time: the interval between consecutive
  # jumps that contains it. An observation that falls exactly on the last
  # jump time is assigned to the last complete interval; pmax() guards the
  # degenerate single-jump trajectory.
  seg <- findInterval(rtp_sys, time.traj, rightmost.closed = TRUE)
  rstates <- C.states[pmax(seg, 1L)]

  # generate observations for discrete-time variable D: for each slice the
  # latest observation of C at or before it, and the time elapsed since then
  slices <- if (N >= 2) 2:N else integer(0)
  ind <- findInterval(slices, rtp_sys)
  D.time <- slices - rtp_sys[ind]
  D.C.states <- rstates[ind]

  # Pade approximant coefficient tables handed to the Stan model
  # (presumably for a matrix exponential -- confirm against the .stan file)
  padeC <- rbind(c(120, 60, 12, 1, 0, 0, 0, 0, 0, 0),
                 c(30240, 15120, 3360, 420, 30, 1, 0, 0, 0, 0),
                 c(17297280, 8648640, 1995840, 277200, 25200, 1512, 56, 1, 0, 0),
                 c(17643225600, 8821612800, 2075673600, 302702400, 30270240, 2162160, 110880, 3960, 90, 1))
  padeCbig <- c(64764752532480000, 32382376266240000, 7771770303897600,
                1187353796428800, 129060195264000, 10559470521600,
                670442572800, 33522128640, 1323241920, 40840800,
                960960, 16380, 182, 1)

  stan.data <- list(padeC = padeC, padeCbig = padeCbig, states = rstates, time = rtp, N = length(rstates),
                    D_states = D.tran.states, D_C_states = D.C.states, D_N = length(D.tran.states), D_time = D.time)

  # 'chains' is spelled out in full instead of relying on partial matching
  fit <- stan("C->D(C missing).stan", data = stan.data, iter = 1000, chains = 1)
  # qualified so that an extract() from another attached package cannot mask it
  la <- rstan::extract(fit)

  # intensities for continuous-time variable
  est.q1 <- histMode(la$inten[, 1])
  est.q2 <- histMode(la$inten[, 2])
  cat("estimated intensity1 = " ,est.q1, "original = ", q1,"\n")
  cat("estimated intensity2 = " ,est.q2, "original = ", q2,"\n")

  # CPTs for discrete-time variables
  # the last two indices are swapped between the Stan array and M:
  # M[i, j, k] is estimated from la$D_CPT[, i, k, j]
  M <- array(0, dim = c(2, 2, 2))
  for (i in 1:2)
  {
    for (j in 1:2)
    {
      for (k in 1:2)
      {
        M[i, j, k] <- histMode(la$D_CPT[, i, k, j])
      }
    }
  }

  list(inten = c(est.q1, est.q2), CPT = M)
}

# Run one learning-curve experiment for a given observation rate.
#
# Random true parameters are drawn (q1, q2 from Gamma(2, 2); each D.tran
# slice from rdirichlet(2, c(1, 1))), a held-out trajectory of 10000 slices
# is simulated, and for each training length in N.slice a model is fitted
# with estimatePar() and scored by the complete-data log-likelihood of the
# held-out trajectory. Results are written as CSV files whose names start
# with "CD(mis)_rate<rate>" inside out.dir.
#
# NOTE(review): generateComp() reseeds the RNG with a fixed seed by default,
# so the training trajectories and the held-out trajectory start from the
# same random draws; the held-out set is not independent of the training
# sets.
#
# Arguments:
#   rate     rate of the observation process for C, forwarded to
#            estimatePar().
#   out.dir  existing directory that receives the result files; the default
#            is the location that was previously hard-coded.
#
# Returns list(true.ll, est.ll): held-out log-likelihood under the true and
# under the estimated parameters, one entry per training length.
run <- function(rate, out.dir = "~/Papers/LearnHTBN")
{
  # fail before the expensive fits rather than when writing the results
  if (!dir.exists(out.dir))
  {
    stop("output directory does not exist: ", out.dir, call. = FALSE)
  }

  q1 <- rgamma(1, shape = 2, rate = 2)
  q2 <- rgamma(1, shape = 2, rate = 2)
  org_inten <- c(q1, q2)
  D.tran <- array(data = rep(0, 8), dim = c(2, 2, 2))
  D.tran[, , 1] <- rdirichlet(2, c(1, 1))
  D.tran[, , 2] <- rdirichlet(2, c(1, 1))
  N.slice <- c(10, 20, 40, 80, 160, 320, 640, 1024, 2048, 4096)

  # evaluation on a new dataset; its sufficient statistics do not depend on
  # the training length, so they are computed once
  heldout <- generateComp(10000, D.tran, q1, q2)$comp

  # number of each instance for discrete-time variables, indexed
  # [state of C, next state of D, previous state of D]
  n.D <- length(heldout$D.states)
  lev <- 1:2
  D.counts <- table(factor(heldout$D.C.states, levels = lev),
                    factor(heldout$D.states[-1], levels = lev),
                    factor(heldout$D.states[-n.D], levels = lev))
  D.counts <- array(as.vector(D.counts), dim = c(2, 2, 2))

  # number of transitions and total time per state for continuous-time variable
  n.C <- length(heldout$C.states)
  from <- heldout$C.states[-n.C]
  to <- heldout$C.states[-1]
  n12 <- sum(from == 1 & to == 2)
  n21 <- sum(from == 2 & to == 1)
  Dur <- c(sum(heldout$time.int[from == 1]), sum(heldout$time.int[from == 2]))

  # sum of count * log(prob), with empty cells contributing 0 so that
  # 0 * log(0) does not turn the total into NaN
  discreteLogLik <- function(cpt)
  {
    sum(ifelse(D.counts > 0, D.counts * log(cpt), 0))
  }
  # log-likelihood of the two-state trajectory of C for rates c(q1, q2)
  continuousLogLik <- function(rates)
  {
    n12 * log(rates[1]) - rates[1] * Dur[1] +
      n21 * log(rates[2]) - rates[2] * Dur[2]
  }

  log_tot_org <- discreteLogLik(D.tran) + continuousLogLik(org_inten)

  n.run <- length(N.slice)
  est.ll <- numeric(n.run)
  est.inten <- vector("list", n.run)
  est.D.tran <- vector("list", n.run)
  for (k in seq_len(n.run))
  {
    N <- N.slice[k]
    cat("Now the rate:", rate, " length: ", N)
    train <- generateComp(N, D.tran, q1, q2)
    fitted <- estimatePar(train$incomp, q1, q2, D.tran, rate, N)

    est.inten[[k]] <- fitted$inten
    est.D.tran[[k]] <- fitted$CPT
    # test on the held-out dataset
    est.ll[k] <- discreteLogLik(fitted$CPT) + continuousLogLik(fitted$inten)
  }
  # the true parameters are the same for every training length
  true.ll <- rep(log_tot_org, n.run)
  true.D.tran <- rep(list(D.tran), n.run)
  true.inten <- rep(list(org_inten), n.run)

  LL.mis <- data.frame(N.slice = log10(N.slice), est.ll = est.ll, true.ll = true.ll, diff = abs(true.ll - est.ll))
  est.inten.data <- matrix(unlist(est.inten), ncol = 2, byrow = TRUE)
  colnames(est.inten.data) <- c("q1", "q2")
  true.inten.data <- matrix(unlist(true.inten), ncol = 2, byrow = TRUE)
  colnames(true.inten.data) <- c("q1", "q2")

  prefix <- paste0(file.path(out.dir, "CD(mis)_rate"), rate)
  write.csv(LL.mis, paste0(prefix, ".LL.diff.csv"), row.names = FALSE, quote = FALSE)
  write.list(true.D.tran, paste0(prefix, ".true.CPT.csv"), quote = FALSE, row.names = FALSE, t.name = N.slice, eol = "\n")
  write.list(est.D.tran, paste0(prefix, ".est.CPT.csv"), quote = FALSE, row.names = FALSE, t.name = N.slice, eol = "\n")
  write.csv(est.inten.data, paste0(prefix, ".est.inten.csv"), quote = FALSE, row.names = FALSE)
  write.csv(true.inten.data, paste0(prefix, ".true.inten.csv"), quote = FALSE, row.names = FALSE)
  list(true.ll = true.ll, est.ll = est.ll)
}




